//https://leetcode.cn/problems/minimize-malware-spread/?envType=daily-question&envId=2024-04-16


// Solution for "Minimize Malware Spread": connected components are found with a
// DFS whose tree edges are recorded in a union-find forest. Removing an
// infected node only helps when it is the sole infected node in its component,
// in which case the whole component is saved.
class Solution {
    vector<int> pa;   // union-find parent links; pa[i] == i marks a root
    vector<char> vis; // DFS visited flags (char avoids the vector<bool> proxy)
public:
    // Returns the root of x's set, compressing the path along the way.
    int find(int x)
    {
        return x == pa[x] ? x : pa[x] = find(pa[x]);
    }

    // Merges the sets containing x and y (no-op if already merged).
    void Union(int x, int y)
    {
        int px = find(x), py = find(y);
        if (px != py)
            pa[px] = py;
    }

    // Visits every node reachable from x through non-zero adjacency entries,
    // unioning each newly discovered node with the node it was reached from.
    // Recursion depth is bounded by the node count.
    void dfs(vector<vector<int>>& graph, int x)
    {
        vis[x] = true;
        const int m = static_cast<int>(graph[x].size());
        for (int i = 0; i < m; i++)
            if (!vis[i] && graph[x][i])
            {
                Union(x, i);
                dfs(graph, i);
            }
    }

    // Returns the node from `initial` whose removal minimises the final number
    // of infected nodes; ties go to the smallest index. If no removal helps,
    // the smallest index in `initial` is returned. Returns -1 when `initial`
    // is empty (no candidate exists).
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        if (initial.empty())
            return -1;
        const int n = static_cast<int>(graph.size());

        // assign (not resize) so state left by a previous call on this same
        // object is fully reset; resize would keep old vis flags set.
        pa.assign(n, 0);
        for (int i = 0; i < n; i++)
            pa[i] = i;
        vis.assign(n, 0);
        for (int i = 0; i < n; i++)
            if (!vis[i])
                dfs(graph, i);

        // Size of every component, keyed by its root.
        vector<int> compSize(n, 0);
        for (int i = 0; i < n; i++)
            ++compSize[find(i)];

        // Number of *distinct* initially infected nodes per component
        // (duplicates in `initial` are counted once).
        vector<char> infected(n, 0);
        for (int x : initial)
            infected[x] = 1;
        vector<int> infectedCount(n, 0);
        for (int i = 0; i < n; i++)
            if (infected[i])
                ++infectedCount[find(i)];

        int res = initial[0], best = 0;
        for (int x : initial)
        {
            const int root = find(x);
            // Only a lone infected node's removal saves its component.
            const int saved = infectedCount[root] == 1 ? compSize[root] : 0;
            if (saved > best || (saved == best && x < res))
            {
                res = x;
                best = saved;
            }
        }
        return res;
    }
};